{
 "cells": [
  {
   "cell_type": "raw",
   "id": "df7d42b9-58a6-434c-a2d7-0b61142f6d3e",
   "metadata": {},
   "source": [
    "---\n",
    "sidebar_position: 5\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2195672-0cab-4967-ba8a-c6544635547d",
   "metadata": {},
   "source": [
    "# How to handle multiple retrievers when doing query analysis\n",
    "\n",
    "Sometimes a query analysis technique lets the model choose which retriever to use. To support this, you need to add some logic that selects the retriever based on the analysis output. This guide shows a simple example (using mock data) of how to do that."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4079b57-4369-49c9-b2ad-c809b5408d7e",
   "metadata": {},
   "source": [
    "## Setup\n",
    "#### Install dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e168ef5c-e54e-49a6-8552-5502854a6f01",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip install -qU langchain langchain-community langchain-openai langchain-chroma"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "79d66a45-a05c-4d22-b011-b1cdbdfc8f9c",
   "metadata": {},
   "source": [
    "#### Set environment variables\n",
    "\n",
    "We'll use OpenAI in this example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40e2979e-a818-4b96-ac25-039336f94319",
   "metadata": {},
   "outputs": [],
   "source": [
    "import getpass\n",
    "import os\n",
    "\n",
    "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass()\n",
    "\n",
    "# Optional, uncomment to trace runs with LangSmith. Sign up here: https://smith.langchain.com.\n",
    "# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n",
    "# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c20b48b8-16d7-4089-bc17-f2d240b3935a",
   "metadata": {},
   "source": [
    "### Create Index\n",
    "\n",
    "We will create two vectorstores over fake information, one per person, and expose each as a retriever."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "1f621694",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_chroma import Chroma\n",
    "from langchain_openai import OpenAIEmbeddings\n",
    "\n",
    "# A single embeddings client is shared by both vectorstores.\n",
    "embeddings = OpenAIEmbeddings(model=\"text-embedding-3-small\")\n",
    "\n",
    "# One collection (and one retriever) per person.\n",
    "texts = [\"Harrison worked at Kensho\"]\n",
    "vectorstore = Chroma.from_texts(texts, embeddings, collection_name=\"harrison\")\n",
    "retriever_harrison = vectorstore.as_retriever(search_kwargs={\"k\": 1})\n",
    "\n",
    "texts = [\"Ankush worked at Facebook\"]\n",
    "vectorstore = Chroma.from_texts(texts, embeddings, collection_name=\"ankush\")\n",
    "retriever_ankush = vectorstore.as_retriever(search_kwargs={\"k\": 1})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57396e23-c192-4d97-846b-5eacea4d6b8d",
   "metadata": {},
   "source": [
    "## Query analysis\n",
    "\n",
    "We will use function calling to structure the output. The structured output contains the search query and the person it concerns, which determines the retriever to use."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "0b51dd76-820d-41a4-98c8-893f6fe0d1ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.pydantic_v1 import BaseModel, Field\n",
    "\n",
    "\n",
    "class Search(BaseModel):\n",
    "    \"\"\"Search for information about a person.\"\"\"\n",
    "\n",
    "    query: str = Field(\n",
    "        ...,\n",
    "        description=\"Query to look up\",\n",
    "    )\n",
    "    person: str = Field(\n",
    "        ...,\n",
    "        description=\"Person to look things up for. Should be `HARRISON` or `ANKUSH`.\",\n",
    "    )"
   ]
  },
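  {
   "cell_type": "markdown",
   "id": "constrained-person-note",
   "metadata": {},
   "source": [
    "The `person` field above is a free-form string, so nothing stops the model from returning a name that has no retriever. One possible way to tighten this, sketched here as an assumption rather than as part of the original example, is to type the field as a `Literal`. `ConstrainedSearch` is a hypothetical name used only for illustration; the rest of this guide keeps using `Search`. Whether the allowed values reach the model as an enum depends on the installed library versions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "constrained-person-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Literal\n",
    "\n",
    "from langchain_core.pydantic_v1 import BaseModel, Field\n",
    "\n",
    "\n",
    "class ConstrainedSearch(BaseModel):\n",
    "    \"\"\"Search for information about a person (hypothetical stricter variant).\"\"\"\n",
    "\n",
    "    query: str = Field(..., description=\"Query to look up\")\n",
    "    # Literal restricts `person` to the names that have a retriever.\n",
    "    person: Literal[\"HARRISON\", \"ANKUSH\"] = Field(\n",
    "        ..., description=\"Person to look things up for.\"\n",
    "    )\n",
    "\n",
    "\n",
    "# \"BOB\" is a made-up invalid value, used only to trigger validation.\n",
    "# pydantic's ValidationError subclasses ValueError, so we catch that.\n",
    "try:\n",
    "    ConstrainedSearch(query=\"workplace\", person=\"BOB\")\n",
    "except ValueError as err:\n",
    "    print(f\"Rejected: {err}\")"
   ]
  },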
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "783c03c3-8c72-4f88-9cf4-5829ce6745d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.prompts import ChatPromptTemplate\n",
    "from langchain_core.runnables import RunnablePassthrough\n",
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "system = \"\"\"You have the ability to issue search queries to get information to help answer user questions.\"\"\"\n",
    "prompt = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        (\"system\", system),\n",
    "        (\"human\", \"{question}\"),\n",
    "    ]\n",
    ")\n",
    "llm = ChatOpenAI(model=\"gpt-3.5-turbo-0125\", temperature=0)\n",
    "structured_llm = llm.with_structured_output(Search)\n",
    "query_analyzer = {\"question\": RunnablePassthrough()} | prompt | structured_llm"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9564078",
   "metadata": {},
   "source": [
    "The `person` field of the output tells us which retriever to route to:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "bc1d3863",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Search(query='workplace', person='HARRISON')"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "query_analyzer.invoke(\"where did Harrison Work\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "af62af17-4f90-4dbd-a8b4-dfff51f1db95",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Search(query='workplace', person='ANKUSH')"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "query_analyzer.invoke(\"where did ankush Work\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7c65b2f-7881-45fc-a47b-a4eaaf48245f",
   "metadata": {},
   "source": [
    "## Retrieval with query analysis\n",
    "\n",
    "So how would we include this in a chain? We just need some simple logic that selects the retriever based on `person` and passes it the search query."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "1e047d87",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.runnables import chain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "4ec0c7fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "retrievers = {\n",
    "    \"HARRISON\": retriever_harrison,\n",
    "    \"ANKUSH\": retriever_ankush,\n",
    "}"
   ]
  },
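  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "8dac7866",
   "metadata": {},
   "outputs": [],
   "source": [
    "@chain\n",
    "def custom_chain(question):\n",
    "    # Run query analysis to get the search query and the target person.\n",
    "    response = query_analyzer.invoke(question)\n",
    "    # Pick the retriever that matches the person chosen by the model.\n",
    "    retriever = retrievers[response.person]\n",
    "    return retriever.invoke(response.query)"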
   ]
  },
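  {
   "cell_type": "markdown",
   "id": "defensive-lookup-note",
   "metadata": {},
   "source": [
    "The lookup `retrievers[response.person]` raises a `KeyError` if the model returns a name in a different case or a name that is not in the mapping. A minimal sketch of a more forgiving lookup follows. `select_retriever` and `demo_retrievers` are hypothetical names introduced for illustration, and the demo uses stand-in strings instead of real retrievers so that it runs without an API key. Inside `custom_chain`, the lookup line could then become `retriever = select_retriever(response.person, retrievers)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "defensive-lookup-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "def select_retriever(person, retrievers, default=None):\n",
    "    \"\"\"Return the retriever registered for `person`, ignoring case and whitespace.\"\"\"\n",
    "    key = person.strip().upper()\n",
    "    if key in retrievers:\n",
    "        return retrievers[key]\n",
    "    # Fall back to a caller-supplied default, if one was given.\n",
    "    if default is not None:\n",
    "        return default\n",
    "    raise ValueError(\n",
    "        f\"No retriever for {person!r}; expected one of {sorted(retrievers)}\"\n",
    "    )\n",
    "\n",
    "\n",
    "# Stand-in values (assumption): plain strings take the place of real retrievers.\n",
    "demo_retrievers = {\"HARRISON\": \"retriever_harrison\", \"ANKUSH\": \"retriever_ankush\"}\n",
    "print(select_retriever(\" ankush \", demo_retrievers))"
   ]
  },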
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "232ad8a7-7990-4066-9228-d35a555f7293",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Document(page_content='Harrison worked at Kensho')]"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "custom_chain.invoke(\"where did Harrison Work\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "28e14ba5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Document(page_content='Ankush worked at Facebook')]"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "custom_chain.invoke(\"where did ankush Work\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
